{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.simplefilter('ignore')\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "pd.set_option('max_columns', None)\n",
    "pd.set_option('max_rows', 1000)\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "import pickle\n",
    "import gc\n",
    "import logging\n",
    "from collections import Counter\n",
    "\n",
    "from tqdm.autonotebook import *\n",
    "\n",
    "import gensim\n",
    "from gensim.models import FastText, Word2Vec\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "import keras\n",
    "from keras import layers\n",
    "from keras import callbacks\n",
    "\n",
    "from bert4keras.snippets import sequence_padding, DataGenerator\n",
    "\n",
    "from keras_multi_head import MultiHead, MultiHeadAttention\n",
    "from keras_self_attention import SeqSelfAttention\n",
    "from keras_position_wise_feed_forward import FeedForward\n",
    "from keras_layer_normalization import LayerNormalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_train = pd.read_csv('raw_data/train_set.csv', sep='\\t')\n",
    "df_test = pd.read_csv('raw_data/test_a.csv', sep='\\t')\n",
    "\n",
    "df_train['text'] = df_train['text'].apply(lambda x: list(map(lambda y: int(y), x.split())))\n",
    "df_test['text'] = df_test['text'].apply(lambda x: list(map(lambda y: int(y), x.split())))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_train, df_valid = train_test_split(df_train, test_size=0.2, random_state=2020)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_classes = 14\n",
    "vocabulary_size = 7600\n",
    "\n",
    "maxlen = 256\n",
    "batch_size = 128\n",
    "embedding_dim = 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data(df):\n",
    "    \"\"\"加载数据\"\"\"\n",
    "    D = list()\n",
    "    for _, row in df.iterrows():\n",
    "        text = row['text']\n",
    "        label = row['label']\n",
    "        D.append((text, int(label)))\n",
    "    return D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = load_data(df_train)\n",
    "valid_data = load_data(df_valid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class data_generator(DataGenerator):\n",
    "    \"\"\"数据生成器\"\"\"\n",
    "\n",
    "    def __init__(self, data, batch_size=32, buffer_size=None, random=False):\n",
    "        super().__init__(data, batch_size, buffer_size)\n",
    "        self.random = random\n",
    "\n",
    "    def __iter__(self, random=False):\n",
    "        batch_token_ids, batch_labels = [], []\n",
    "        for is_end, (text, label) in self.sample(random):\n",
    "            token_ids = text[:maxlen] if len(text) > maxlen else text + (maxlen - len(text)) * [0]\n",
    "            batch_token_ids.append(token_ids)\n",
    "            batch_labels.append([label])\n",
    "            if len(batch_token_ids) == self.batch_size or is_end:\n",
    "                batch_token_ids = sequence_padding(batch_token_ids)\n",
    "                batch_labels = sequence_padding(batch_labels)\n",
    "                yield [batch_token_ids], batch_labels\n",
    "                batch_token_ids, batch_labels = [], []\n",
    "\n",
    "    def forfit(self):\n",
    "        while True:\n",
    "            for d in self.__iter__(self.random):\n",
    "                yield d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_generator = data_generator(train_data, batch_size, random=True)\n",
    "valid_generator = data_generator(valid_data, batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_model():\n",
    "\n",
    "    inp = layers.Input(shape=(maxlen,))\n",
    "\n",
    "    emb_layer = layers.Embedding(\n",
    "        input_dim=vocabulary_size,\n",
    "        output_dim=embedding_dim,\n",
    "        input_length=maxlen\n",
    "    )(inp)\n",
    "\n",
    "    sdrop = layers.SpatialDropout1D(rate=0.2)\n",
    "\n",
    "    emb_layer = sdrop(emb_layer)\n",
    "\n",
    "    mha1 = MultiHeadAttention(head_num=16)(emb_layer)\n",
    "    mha1 = layers.Dropout(0.01)(mha1)\n",
    "    mha1 = layers.Add()([emb_layer, mha1])\n",
    "    mha1 = LayerNormalization()(mha1)\n",
    "    mha1 = layers.Dropout(0.01)(mha1)\n",
    "    mha1_ff = FeedForward(128)(mha1)\n",
    "    mha1_out = layers.Add()([mha1, mha1_ff])\n",
    "    mha1_out = LayerNormalization()(mha1_out)\n",
    "\n",
    "    mha2 = MultiHeadAttention(head_num=16)(mha1_out)\n",
    "    mha2 = layers.Dropout(0.01)(mha2)\n",
    "    mha2 = layers.Add()([mha1_out, mha2])\n",
    "    mha2 = LayerNormalization()(mha2)\n",
    "    mha2 = layers.Dropout(0.01)(mha2)\n",
    "    mha2_ff = FeedForward(128)(mha2)\n",
    "    mha2_out = layers.Add()([mha2, mha2_ff])\n",
    "    mha2_out = LayerNormalization()(mha2_out)\n",
    "    \n",
    "    lstm = layers.Bidirectional(layers.LSTM(128, return_sequences=True))(mha2_out)\n",
    "\n",
    "    avg_pool = layers.GlobalAveragePooling1D()(lstm)\n",
    "    max_pool = layers.GlobalMaxPool1D()(lstm)\n",
    "\n",
    "    x = layers.Concatenate()([avg_pool, max_pool])\n",
    "\n",
    "    x = layers.Dense(128, activation='relu')(x)\n",
    "    x = layers.BatchNormalization()(x)\n",
    "\n",
    "    x = layers.Dense(64, activation='relu')(x)\n",
    "    x = layers.BatchNormalization()(x)\n",
    "\n",
    "    x = layers.Dropout(0.2)(x)\n",
    "\n",
    "    out = layers.Dense(num_classes, activation='softmax')(x)\n",
    "    model = keras.Model(inputs=inp, outputs=out)\n",
    "    model.compile(loss='sparse_categorical_crossentropy',\n",
    "                  optimizer=keras.optimizers.Adam(1e-4),\n",
    "                  metrics=['accuracy'])\n",
    "    \n",
    "    return model\n",
    "\n",
    "model = build_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Evaluator(callbacks.Callback):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.best_val_f1 = 0.\n",
    "\n",
    "    def evaluate(self):\n",
    "        y_true, y_pred = list(), list()\n",
    "        for x, y in valid_generator:\n",
    "            y_true.append(y)\n",
    "            y_pred.append(self.model.predict(x).argmax(axis=1))\n",
    "        y_true = np.concatenate(y_true)\n",
    "        y_pred = np.concatenate(y_pred)\n",
    "        f1 = f1_score(y_true, y_pred, average='macro')\n",
    "        return f1\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        val_f1 = self.evaluate()\n",
    "        if val_f1 > self.best_val_f1:\n",
    "            self.best_val_f1 = val_f1\n",
    "        logs['val_f1'] = val_f1\n",
    "        print(f'val_f1: {val_f1:.5f}, best_val_f1: {self.best_val_f1:.5f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "callbacks = [\n",
    "    Evaluator(),\n",
    "    callbacks.EarlyStopping(\n",
    "        monitor='val_accuracy', \n",
    "        mode='max',\n",
    "        patience=5, \n",
    "        verbose=1\n",
    "    ),\n",
    "    callbacks.ModelCheckpoint(\n",
    "        './models/model.h5',\n",
    "        monitor='val_f1',\n",
    "        save_weights_only=True,\n",
    "        save_best_only=True,\n",
    "        verbose=1,\n",
    "        mode='max'\n",
    "    ),\n",
    "    callbacks.ReduceLROnPlateau(\n",
    "        monitor='val_f1',\n",
    "        factor=0.1,\n",
    "        patience=2,\n",
    "        verbose=1,\n",
    "        mode='max',\n",
    "        epsilon=1e-6\n",
    "    )\n",
    "    \n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "model.fit(\n",
    "    train_generator.forfit(),\n",
    "    steps_per_epoch=len(train_generator),\n",
    "    epochs=100,\n",
    "    callbacks=callbacks,\n",
    "    validation_data=valid_generator.forfit(),\n",
    "    validation_steps=len(valid_generator)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_test['label'] = 0\n",
    "test_data = load_data(df_test)\n",
    "test_generator = data_generator(test_data, batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result = model.predict_generator(test_generator.forfit(), steps=len(test_generator))\n",
    "result = result.argmax(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_test['label'] = result\n",
    "df_test.to_csv('submission.csv', index=False, columns=['label'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
